一维 Thread Tile 并行优化
info
如前文所述,要尽可能的使用计算来掩盖访存延迟,因此可以使用一 个 thread 负责计算多个数据,本篇则使用一个 thread 在数据矩阵的行上进行分片,一个线程负责一小列的数据计算。
这个新的内核与前一个内核相似,但增加了一个新的内循环,用于计算每个线程的多个 C 条目。我们现在使用的 SMEM 缓存大小为 BM*BK + BN*BK = 64*8 + 64*8 = 1024
个浮点数,每个块总共为 4KB。下面是一个可视化效果,我用橙色和红色突出显示了两个线程以及它们在内循环中访问的值。
info
之所以缓存变成了长方形,是因为一个线程负责一小列的计算,本文使用的分片大小为 8,实际上负责该区域的线程数量是 。
在这个内核中,所有重要的更改都发生在内循环中。与之前相比,从 GMEM 到 SMEM 的加载基本相同。具体来看,我们分配了一个线程本地缓存 threadResults[TM]
用于寄存器文件。
// 为寄存器文件分配线程本地缓存
float threadResults[TM] = {0.0};
// 外循环遍历
for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
// 填充SMEM缓存(与之前相同)
As[innerRowA * BK + innerColA] = A[innerRowA * K + innerColA];
Bs[innerRowB * BN + innerColB] = B[innerRowB * N + innerColB];
__syncthreads();
// 推进外循环的指针
A += BK;
B += BK * N;
// 计算每个线程的结果
for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
// 我们将点积循环 放在外循环中,这有助于重用Bs,我们可以将其缓存在tmp变量中。
float Btmp = Bs[dotIdx * BN + threadCol];
for (uint resIdx = 0; resIdx < TM; ++resIdx) {
threadResults[resIdx] +=
As[(threadRow * TM + resIdx) * BK + dotIdx] * Btmp;
}
}
__syncthreads();
}
这个内核实现了约 8600 GFLOPs,比我们上一个内核快 2.2 倍。让我们计算一下在我们上一个内核中,每个线程执行了多少内存访问 (每个线程计算一个结果):
- GMEM:K/32 次外循环迭代 * 2 次加载
- SMEM:K/32 次外循环迭代 * BLOCKSIZE(=32) * 2 次加载
- 每个结果的内存访问:K/16 GMEM,K*2 SMEM
对于我们的新内核,其中每个线程计算了八个结果:
- GMEM:K/8 次外循环迭代 * 2 次加载
- SMEM:K/8 次外循环迭代 * BK(=8)*(1 + TM(=8))
- 每个结果的内存访问:K/32 GMEM,K*9/8 SMEM
正如预期的那样,我们现在每个指令的循环周期中由于内存压力造成的停顿明显减少。